package com.i72.basic;


import lombok.extern.slf4j.Slf4j;
import org.apache.commons.pool2.impl.GenericObjectPoolConfig;
import redis.clients.jedis.HostAndPort;
import redis.clients.jedis.Jedis;
import redis.clients.jedis.JedisCluster;
import redis.clients.jedis.JedisClusterInfoCache;
import redis.clients.jedis.JedisPool;
import redis.clients.jedis.Pipeline;
import redis.clients.jedis.exceptions.JedisAskDataException;
import redis.clients.jedis.exceptions.JedisConnectionException;
import redis.clients.jedis.exceptions.JedisException;
import redis.clients.jedis.exceptions.JedisMovedDataException;
import redis.clients.jedis.exceptions.JedisRedirectionException;
import redis.clients.util.JedisClusterCRC16;
import redis.clients.util.SafeEncoder;

import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;

/**
 * {@link JedisCluster} extension that batches multi-key operations with pipelining.
 * <p>
 * Keys are grouped by the {@link JedisPool} that serves their hash slot, one pipeline is run per
 * pool, and replies that came back as a connection error or a MOVED/ASK redirection are retried
 * one key at a time through the regular cluster commands inherited from {@link JedisCluster}.
 * Replies that are any other kind of exception are logged and left out of the returned maps.
 *
 * @author jiangj
 * @version 1.0.0
 * @createTime 2022-01-21 22:41:00
 */
@Slf4j
public class PipelineCluster extends JedisCluster {
    /** Slot-to-pool table, obtained by reflection from the connection handler of the superclass. */
    private final JedisClusterInfoCache cache;

    /**
     * Creates a cluster client using one value for both timeouts and the default redirection limit.
     *
     * @param nodes      seed nodes of the cluster
     * @param timeout    value used for both the connection and the socket timeout
     * @param poolConfig configuration for the per-node connection pools
     */
    public PipelineCluster(Set<HostAndPort> nodes, int timeout, final GenericObjectPoolConfig poolConfig) {
        this(nodes, timeout, DEFAULT_MAX_REDIRECTIONS, poolConfig);
    }

    /**
     * Creates a cluster client with the default timeout and the default redirection limit.
     *
     * @param nodes      seed nodes of the cluster
     * @param poolConfig configuration for the per-node connection pools
     */
    public PipelineCluster(Set<HostAndPort> nodes, final GenericObjectPoolConfig poolConfig) {
        this(nodes, DEFAULT_TIMEOUT, DEFAULT_MAX_REDIRECTIONS, poolConfig);
    }

    /**
     * Creates a cluster client using one value for both the connection and the socket timeout.
     *
     * @param jedisClusterNode seed nodes of the cluster
     * @param timeout          value used for both the connection and the socket timeout
     * @param maxAttempts      attempt limit handed to {@link JedisCluster}
     * @param poolConfig       configuration for the per-node connection pools
     */
    public PipelineCluster(Set<HostAndPort> jedisClusterNode, int timeout, int maxAttempts, final GenericObjectPoolConfig poolConfig) {
        this(jedisClusterNode, timeout, timeout, maxAttempts, poolConfig);
    }

    /**
     * Creates a cluster client and captures the slot cache of its connection handler.
     *
     * @param jedisClusterNode  seed nodes of the cluster
     * @param connectionTimeout connection timeout handed to {@link JedisCluster}
     * @param soTimeout         socket timeout handed to {@link JedisCluster}
     * @param maxAttempts       attempt limit handed to {@link JedisCluster}
     * @param poolConfig        configuration for the per-node connection pools
     * @throws JedisException if no non-null {@code cache} field can be read from the connection handler
     */
    public PipelineCluster(Set<HostAndPort> jedisClusterNode, int connectionTimeout, int soTimeout, int maxAttempts, final GenericObjectPoolConfig poolConfig) {
        super(jedisClusterNode, connectionTimeout, soTimeout, maxAttempts, poolConfig);
        this.cache = findSlotCache(this.connectionHandler);
    }

    /**
     * Walks the class hierarchy of {@code handler} looking for a field named {@code cache} and
     * returns the first non-null value found.
     *
     * @param handler the connection handler to inspect
     * @return the slot cache held by the handler
     * @throws JedisException if no class in the hierarchy yields a non-null {@code cache} field
     */
    private static JedisClusterInfoCache findSlotCache(Object handler) {
        for (Class<?> type = handler.getClass(); type != null && type != Object.class; type = type.getSuperclass()) {
            try {
                Field field = type.getDeclaredField("cache");
                field.setAccessible(true);
                JedisClusterInfoCache found = (JedisClusterInfoCache) field.get(handler);
                if (found != null) {
                    return found;
                }
            } catch (NoSuchFieldException | IllegalAccessException ignored) {
                // Deliberately ignored: the field is not declared (or not readable) at this
                // level, so the search simply continues with the superclass.
            }
        }
        throw new JedisException("JedisClusterInfoCache is null");
    }

    /**
     * Batch GET.
     *
     * @param keys keys to read
     * @return key to value; keys without a value (or whose command failed) are absent
     */
    public Map<String, String> mget(final List<String> keys) {
        return runBatch(keys, (pipeline, pipelineKeys) -> {
            for (String key : pipelineKeys) {
                pipeline.get(key);
            }
        }, key -> get(key), Object::toString);
    }

    /**
     * Batch EXPIRE.
     *
     * @param keyTimeMap key to expiry in seconds
     * @return key to command reply; keys with a null reply (or whose command failed) are absent
     */
    public Map<String, Long> mexpire(final Map<String, Integer> keyTimeMap) {
        if (keyTimeMap == null || keyTimeMap.isEmpty()) {
            return new HashMap<>();
        }
        return runBatch(new ArrayList<>(keyTimeMap.keySet()), (pipeline, pipelineKeys) -> {
            for (String key : pipelineKeys) {
                pipeline.expire(key, keyTimeMap.get(key));
            }
        }, key -> expire(key, keyTimeMap.get(key)), reply -> Long.valueOf(reply.toString()));
    }

    /**
     * Batch HGETALL.
     *
     * @param keys hash keys to read
     * @return key to field/value map; keys with a null reply (or whose command failed) are absent
     */
    @SuppressWarnings("unchecked")
    public Map<String, Map<String, String>> mhgetAll(List<String> keys) {
        return runBatch(keys, (pipeline, pipelineKeys) -> {
            for (String key : pipelineKeys) {
                pipeline.hgetAll(key);
            }
        }, key -> hgetAll(key), reply -> (Map<String, String>) reply);
    }

    /**
     * Batch HMSET.
     *
     * @param keyValueMap key to the field/value map to store under it
     * @return key to command reply; keys with a null reply (or whose command failed) are absent
     */
    public Map<String, String> mhmset(final Map<String, Map<String, String>> keyValueMap) {
        if (keyValueMap == null || keyValueMap.isEmpty()) {
            return new HashMap<>();
        }
        return runBatch(new ArrayList<>(keyValueMap.keySet()), (pipeline, pipelineKeys) -> {
            for (String key : pipelineKeys) {
                pipeline.hmset(key, keyValueMap.get(key));
            }
        }, key -> {
            Map<String, String> hash = keyValueMap.get(key);
            return hash == null ? null : hmset(key, hash);
        }, Object::toString);
    }

    /**
     * Batch SETEX.
     *
     * @param keyValueMap key to value
     * @param second      expiry in seconds, applied to every key (must not be null)
     * @return key to command reply; keys with a null reply (or whose command failed) are absent
     */
    public Map<String, String> mset(final Map<String, String> keyValueMap, Integer second) {
        if (keyValueMap == null || keyValueMap.isEmpty()) {
            return new HashMap<>();
        }
        // Unbox once up front: a null expiry fails here, before any connection is borrowed.
        final int ttlSeconds = second;
        return runBatch(new ArrayList<>(keyValueMap.keySet()), (pipeline, pipelineKeys) -> {
            for (String key : pipelineKeys) {
                pipeline.setex(key, ttlSeconds, keyValueMap.get(key));
            }
        }, key -> {
            String value = keyValueMap.get(key);
            // Retry with SETEX as well; a plain SET would store the key without its expiry.
            return value == null ? null : setex(key, ttlSeconds, value);
        }, Object::toString);
    }

    /**
     * Batch DEL.
     *
     * @param keys keys to delete
     * @return key to command reply; keys with a null reply (or whose command failed) are absent
     */
    public Map<String, Long> mdel(final List<String> keys) {
        return runBatch(keys, (pipeline, pipelineKeys) -> {
            for (String key : pipelineKeys) {
                pipeline.del(key);
            }
        }, key -> del(key), reply -> Long.valueOf(reply.toString()));
    }

    /**
     * Batch ZRANGEBYSCORE with numeric bounds.
     *
     * @param keys sorted-set keys to read
     * @param min  lower score bound
     * @param max  upper score bound
     * @return key to members; keys with a null reply (or whose command failed) are absent
     */
    @SuppressWarnings("unchecked")
    public Map<String, Set<String>> mzrangeByScore(final List<String> keys, final double min, final double max) {
        return runBatch(keys, (pipeline, pipelineKeys) -> {
            for (String key : pipelineKeys) {
                pipeline.zrangeByScore(key, min, max);
            }
        }, key -> zrangeByScore(key, min, max), reply -> (Set<String>) reply);
    }

    /**
     * Batch ZRANGEBYSCORE with string bounds.
     *
     * @param keys sorted-set keys to read
     * @param min  lower score bound, passed to the command as given
     * @param max  upper score bound, passed to the command as given
     * @return key to members; keys with a null reply (or whose command failed) are absent
     */
    @SuppressWarnings("unchecked")
    public Map<String, Set<String>> mzrangeByScore(final List<String> keys, final String min, final String max) {
        return runBatch(keys, (pipeline, pipelineKeys) -> {
            for (String key : pipelineKeys) {
                pipeline.zrangeByScore(key, min, max);
            }
        }, key -> zrangeByScore(key, min, max), reply -> (Set<String>) reply);
    }

    /**
     * Batch GET returning raw bytes.
     *
     * @param keys keys to read
     * @return key to value bytes; keys without a value (or whose command failed) are absent
     */
    public Map<String, byte[]> mgetBytes(final List<String> keys) {
        return runBatch(keys, (pipeline, pipelineKeys) -> {
            for (String key : pipelineKeys) {
                pipeline.get(SafeEncoder.encode(key));
            }
        }, key -> get(SafeEncoder.encode(key)), reply -> (byte[]) reply);
    }

    /**
     * Batch SETEX with {@code byte[]} values.
     *
     * @param keyValueMap key to value bytes
     * @param second      expiry in seconds, applied to every key (must not be null)
     * @return key to command reply; keys with a null reply (or whose command failed) are absent
     */
    public Map<String, String> msetBytes(final Map<String, byte[]> keyValueMap, Integer second) {
        if (keyValueMap == null || keyValueMap.isEmpty()) {
            return new HashMap<>();
        }
        // Unbox once up front: a null expiry fails here, before any connection is borrowed.
        final int ttlSeconds = second;
        return runBatch(new ArrayList<>(keyValueMap.keySet()), (pipeline, pipelineKeys) -> {
            for (String key : pipelineKeys) {
                pipeline.setex(SafeEncoder.encode(key), ttlSeconds, keyValueMap.get(key));
            }
        }, key -> {
            byte[] value = keyValueMap.get(key);
            // Retry with SETEX as well; a plain SET would store the key without its expiry.
            return value == null ? null : setex(SafeEncoder.encode(key), ttlSeconds, value);
        }, Object::toString);
    }

    /**
     * Core execution method: groups the keys by pool, runs one pipeline per pool and hands the
     * collected replies to {@code resultFunction}.
     * <p>
     * Exceptions are not caught here; callers are expected to handle them.
     *
     * @param keys            keys to operate on
     * @param pipelineCommand queues exactly one command per key on the given pipeline
     * @param resultFunction  converts the key-to-raw-reply map into the final result
     * @param <R>             type produced by {@code resultFunction}
     * @return the value produced by {@code resultFunction}, or {@code null} when {@code keys} is
     *         null or empty
     * @throws JedisException if a pipeline returns a different number of replies than keys sent
     */
    public <R> R execute(List<String> keys, PipelineCommand<Pipeline, List<String>> pipelineCommand, Function<Map<String, Object>, R> resultFunction) {
        if (keys == null || keys.isEmpty()) {
            return null;
        }
        // Map every key to the pool that serves its slot.
        Map<JedisPool, List<String>> poolKeysMap = getPoolKeyMap(keys);

        Map<String, Object> resultMap = new HashMap<>();
        for (Map.Entry<JedisPool, List<String>> entry : poolKeysMap.entrySet()) {
            List<String> subKeys = entry.getValue();
            if (subKeys == null || subKeys.isEmpty()) {
                continue;
            }
            try (Jedis jedis = entry.getKey().getResource()) {
                Pipeline pipeline = jedis.pipelined();
                try {
                    pipelineCommand.execute(pipeline, subKeys);
                    List<Object> subResultList = pipeline.syncAndReturnAll();
                    if (subResultList == null || subResultList.size() != subKeys.size()) {
                        throw new JedisException("pipeline request keys and return list mismatched: request keys: " + subKeys + ", return list: " + subResultList);
                    }
                    // Replies are matched to keys by position.
                    for (int i = 0; i < subKeys.size(); i++) {
                        resultMap.put(subKeys.get(i), subResultList.get(i));
                    }
                } finally {
                    // Must happen inside the resource block: a finally attached to the
                    // try-with-resources itself would run only after the connection had
                    // already been closed (i.e. handed back to the pool).
                    pipeline.clear();
                }
            }
        }
        return resultFunction.apply(resultMap);
    }

    /**
     * Runs {@link #execute} for {@code keys} and converts the raw replies with
     * {@link #collectResults}.
     *
     * @param keys           keys to operate on
     * @param command        queues one command per key on the pipeline
     * @param retryCommand   single-key fallback used for replies that need a retry
     * @param replyConverter converts a successful raw reply into the result value
     * @param <V>            value type of the returned map
     * @return a new mutable map; empty when {@code keys} is null or empty
     */
    private <V> Map<String, V> runBatch(List<String> keys, PipelineCommand<Pipeline, List<String>> command,
                                        Function<String, V> retryCommand, Function<Object, V> replyConverter) {
        if (keys == null || keys.isEmpty()) {
            return new HashMap<>();
        }
        return execute(keys, command, replies -> collectResults(replies, retryCommand, replyConverter));
    }

    /**
     * Turns raw pipeline replies into typed values.
     * <p>
     * Null replies are skipped. Replies for which {@link #isNeedRetries} returns true are
     * re-executed through {@code retryCommand}. Any other exception reply has already been logged
     * by {@link #isNeedRetries} and is skipped rather than being converted as if it were a value.
     *
     * @param resultMap      key to raw reply, may be null
     * @param retryCommand   single-key fallback; a null return leaves the key out of the result
     * @param replyConverter converts a successful raw reply; a null return leaves the key out
     * @param <V>            value type of the returned map
     * @return a new mutable map of the keys that produced a non-null value
     */
    private <V> Map<String, V> collectResults(Map<String, Object> resultMap, Function<String, V> retryCommand,
                                              Function<Object, V> replyConverter) {
        Map<String, V> result = new HashMap<>();
        if (resultMap == null || resultMap.isEmpty()) {
            return result;
        }
        for (Map.Entry<String, Object> entry : resultMap.entrySet()) {
            String key = entry.getKey();
            Object reply = entry.getValue();
            if (reply == null) {
                continue;
            }
            V value;
            if (isNeedRetries(reply)) {
                value = retryCommand.apply(key);
            } else if (reply instanceof Exception) {
                // Non-retryable failure: not a value, so it must not reach the converter.
                continue;
            } else {
                value = replyConverter.apply(reply);
            }
            if (value != null) {
                result.put(key, value);
            }
        }
        return result;
    }

    /**
     * Groups the keys by the pool that serves their hash slot, keeping first-seen pool order.
     *
     * @param keys keys to group
     * @return pool to the keys it serves
     */
    private Map<JedisPool, List<String>> getPoolKeyMap(List<String> keys) {
        Map<JedisPool, List<String>> poolKeysMap = new LinkedHashMap<>();
        for (String key : keys) {
            JedisPool jedisPool = getJedisPoolFromSlot(JedisClusterCRC16.getSlot(key));
            poolKeysMap.computeIfAbsent(jedisPool, pool -> new ArrayList<>()).add(key);
        }
        return poolKeysMap;
    }

    /**
     * Decides whether a raw pipeline reply should be retried as a single-key cluster command.
     * <p>
     * Connection errors and MOVED/ASK redirections are retried; a MOVED reply additionally
     * triggers a slot cache refresh. Other exceptions are logged at error level and not retried.
     *
     * @param obj raw reply taken from the pipeline result list
     * @return true if the command for this reply should be retried
     */
    private boolean isNeedRetries(Object obj) {
        if (obj instanceof JedisConnectionException) {
            return true;
        }
        if (obj instanceof JedisMovedDataException) {
            // it rebuilds cluster's slot cache
            // recommended by Redis cluster specification
            this.connectionHandler.renewSlotCache();
            log.warn("PipelineCluster JedisMovedDataException occurred and then renewSlotCache", (Exception) obj);
            return true;
        }
        if (obj instanceof JedisAskDataException) {
            log.warn("PipelineCluster JedisAskDataException occurred", (Exception) obj);
            return true;
        }
        if (obj instanceof JedisRedirectionException) {
            log.error("PipelineCluster JedisRedirectionException occurred", (Exception) obj);
        } else if (obj instanceof Exception) {
            log.error("PipelineCluster Exception occurred", (Exception) obj);
        }
        return false;
    }

    /**
     * Looks up the pool that serves the given slot.
     * <p>
     * When the cache has no pool for the slot, the slot cache is refreshed once and the lookup is
     * repeated, so that a missing mapping surfaces as a descriptive exception instead of a null
     * pool being dereferenced later.
     *
     * @param slot hash slot of a key
     * @return the pool serving the slot, never null
     * @throws JedisException if no pool is known for the slot even after the refresh
     */
    private JedisPool getJedisPoolFromSlot(int slot) {
        JedisPool pool = cache.getSlotPool(slot);
        if (pool == null) {
            this.connectionHandler.renewSlotCache();
            pool = cache.getSlotPool(slot);
        }
        if (pool == null) {
            throw new JedisException("No JedisPool available for slot " + slot);
        }
        return pool;
    }

}
